{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# SageMaker Data Wrangler Job Notebook\n",
    "This notebook uses the Data Wrangler .flow file to submit a SageMaker Data Wrangler Job\n",
    "with the following steps:\n",
    "\n",
    "* Push Data Wrangler .flow file to S3\n",
    "* Parse the .flow file inputs, and create the argument dictionary to submit to a boto client\n",
    "* Submit the ProcessingJob arguments and wait for Job completion\n",
    "\n",
    "Optionally, the notebook also gives an example of starting a SageMaker XGBoost TrainingJob using\n",
    "the newly processed data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This notebook requires SageMaker Python SDK 2.x and pins version 2.20.0\n",
    "import sagemaker\n",
    "import subprocess\n",
    "import sys\n",
    "\n",
    "original_version = sagemaker.__version__\n",
    "if sagemaker.__version__ != \"2.20.0\":\n",
    "    subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"sagemaker==2.20.0\"])\n",
    "    import importlib\n",
    "\n",
    "    importlib.reload(sagemaker)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import time\n",
    "import uuid\n",
    "\n",
    "import boto3\n",
    "import sagemaker"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters\n",
    "The following lists configurable parameters that are used throughout this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# S3 bucket for saving processing job outputs\n",
    "# Feel free to specify a different bucket here if you wish.\n",
    "sess = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "prefix = \"data_wrangler_flows\"\n",
    "flow_id = f\"{time.strftime('%d-%H-%M-%S', time.gmtime())}-{str(uuid.uuid4())[:8]}\"\n",
    "flow_name = f\"flow-{flow_id}\"\n",
    "flow_uri = f\"s3://{bucket}/{prefix}/{flow_name}.flow\"\n",
    "\n",
    "flow_file_name = \"bias_and_explainability.flow\"\n",
    "\n",
    "iam_role = sagemaker.get_execution_role()\n",
    "\n",
    "container_uri = \"663277389841.dkr.ecr.us-east-1.amazonaws.com/sagemaker-data-wrangler-container:1.1.1\"\n",
    "\n",
    "# Processing job resource configuration\n",
    "# A Data Wrangler processing job supports only 1 instance.\n",
    "instance_count = 1\n",
    "instance_type = \"ml.m5.4xlarge\"\n",
    "\n",
    "# Processing Job Path URI Information\n",
    "output_prefix = f\"export-{flow_name}/output\"\n",
    "output_path = f\"s3://{bucket}/{output_prefix}\"\n",
    "output_name = \"14039109-2da9-49b4-8eee-df39306c9c47.default\"\n",
    "\n",
    "processing_job_name = f\"data-wrangler-flow-processing-{flow_id}\"\n",
    "\n",
    "processing_dir = \"/opt/ml/processing\"\n",
    "\n",
    "# Set the content type used to write each output.\n",
    "# Supported options are 'CSV' and 'PARQUET'; the default is 'CSV'.\n",
    "output_content_type = \"CSV\"\n",
    "\n",
    "# URL to use for sagemaker client.\n",
    "# If this is None, boto will automatically construct the appropriate URL to use\n",
    "# when communicating with sagemaker.\n",
    "sagemaker_endpoint_url = None"
   ]
  },
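  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal sketch, the content type could be checked before the job is submitted, so that a\n",
    "typo fails in the notebook rather than inside the container. The helper name\n",
    "`normalize_output_content_type` is hypothetical and not part of Data Wrangler or the SageMaker\n",
    "SDK; the allowed values are taken from the comment in the parameters cell above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper for illustration; not part of any AWS library.\n",
    "EXAMPLE_SUPPORTED_CONTENT_TYPES = (\"CSV\", \"PARQUET\")\n",
    "\n",
    "\n",
    "def normalize_output_content_type(value):\n",
    "    \"\"\"Return the upper-cased content type, or raise if it is not supported.\"\"\"\n",
    "    normalized = str(value).strip().upper()\n",
    "    if normalized not in EXAMPLE_SUPPORTED_CONTENT_TYPES:\n",
    "        raise ValueError(\n",
    "            f\"Unsupported output content type {value!r}; \"\n",
    "            f\"expected one of {EXAMPLE_SUPPORTED_CONTENT_TYPES}.\"\n",
    "        )\n",
    "    return normalized\n",
    "\n",
    "\n",
    "# Lower-case input is accepted and normalized.\n",
    "print(normalize_output_content_type(\"parquet\"))"
   ]
  },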
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Push Flow to S3\n",
    "Use the following cell to upload the Data Wrangler .flow file to Amazon S3 so that\n",
    "it can be used as an input to the processing job."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load .flow file\n",
    "with open(flow_file_name) as f:\n",
    "    flow = json.load(f)\n",
    "\n",
    "# Upload to S3\n",
    "s3_client = boto3.client(\"s3\")\n",
    "s3_client.upload_file(flow_file_name, bucket, f\"{prefix}/{flow_name}.flow\")\n",
    "\n",
    "print(f\"Data Wrangler flow file uploaded to {flow_uri}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create boto3 Processing Job arguments\n",
    "\n",
    "This notebook submits a Processing Job using boto3, which requires an argument dictionary. The\n",
    "utility functions below create Processing Job inputs for the following sources: S3, Athena, and\n",
    "Redshift. The argument dictionary is then generated from the parsed inputs and job configuration\n",
    "such as the instance type."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_flow_notebook_processing_input(base_dir, flow_s3_uri):\n",
    "    return {\n",
    "        \"InputName\": \"flow\",\n",
    "        \"S3Input\": {\n",
    "            \"LocalPath\": f\"{base_dir}/flow\",\n",
    "            \"S3Uri\": flow_s3_uri,\n",
    "            \"S3DataType\": \"S3Prefix\",\n",
    "            \"S3InputMode\": \"File\",\n",
    "        },\n",
    "    }\n",
    "\n",
    "\n",
    "def create_s3_processing_input(base_dir, name, dataset_definition):\n",
    "    return {\n",
    "        \"InputName\": name,\n",
    "        \"S3Input\": {\n",
    "            \"LocalPath\": f\"{base_dir}/{name}\",\n",
    "            \"S3Uri\": dataset_definition[\"s3ExecutionContext\"][\"s3Uri\"],\n",
    "            \"S3DataType\": \"S3Prefix\",\n",
    "            \"S3InputMode\": \"File\",\n",
    "        },\n",
    "    }\n",
    "\n",
    "\n",
    "def create_redshift_processing_input(base_dir, name, dataset_definition):\n",
    "    return {\n",
    "        \"InputName\": name,\n",
    "        \"DatasetDefinition\": {\n",
    "            \"RedshiftDatasetDefinition\": {\n",
    "                \"ClusterId\": dataset_definition[\"clusterIdentifier\"],\n",
    "                \"Database\": dataset_definition[\"database\"],\n",
    "                \"DbUser\": dataset_definition[\"dbUser\"],\n",
    "                \"QueryString\": dataset_definition[\"queryString\"],\n",
    "                \"ClusterRoleArn\": dataset_definition[\"unloadIamRole\"],\n",
    "                \"OutputS3Uri\": f'{dataset_definition[\"s3OutputLocation\"]}{name}/',\n",
    "                \"OutputFormat\": dataset_definition[\"outputFormat\"].upper(),\n",
    "            },\n",
    "            \"LocalPath\": f\"{base_dir}/{name}\",\n",
    "        },\n",
    "    }\n",
    "\n",
    "\n",
    "def create_athena_processing_input(base_dir, name, dataset_definition):\n",
    "    return {\n",
    "        \"InputName\": name,\n",
    "        \"DatasetDefinition\": {\n",
    "            \"AthenaDatasetDefinition\": {\n",
    "                \"Catalog\": dataset_definition[\"catalogName\"],\n",
    "                \"Database\": dataset_definition[\"databaseName\"],\n",
    "                \"QueryString\": dataset_definition[\"queryString\"],\n",
    "                \"OutputS3Uri\": f'{dataset_definition[\"s3OutputLocation\"]}{name}/',\n",
    "                \"OutputFormat\": dataset_definition[\"outputFormat\"].upper(),\n",
    "            },\n",
    "            \"LocalPath\": f\"{base_dir}/{name}\",\n",
    "        },\n",
    "    }\n",
    "\n",
    "\n",
    "def create_processing_inputs(processing_dir, flow, flow_uri):\n",
    "    \"\"\"Helper function for creating processing inputs\n",
    "    :param processing_dir: base directory for inputs inside the processing container\n",
    "    :param flow: loaded Data Wrangler flow (parsed JSON)\n",
    "    :param flow_uri: S3 URI of the Data Wrangler flow file\n",
    "    \"\"\"\n",
    "    processing_inputs = []\n",
    "    flow_processing_input = create_flow_notebook_processing_input(processing_dir, flow_uri)\n",
    "    processing_inputs.append(flow_processing_input)\n",
    "\n",
    "    for node in flow[\"nodes\"]:\n",
    "        if \"dataset_definition\" in node[\"parameters\"]:\n",
    "            data_def = node[\"parameters\"][\"dataset_definition\"]\n",
    "            name = data_def[\"name\"]\n",
    "            source_type = data_def[\"datasetSourceType\"]\n",
    "\n",
    "            if source_type == \"S3\":\n",
    "                s3_processing_input = create_s3_processing_input(processing_dir, name, data_def)\n",
    "                processing_inputs.append(s3_processing_input)\n",
    "            elif source_type == \"Athena\":\n",
    "                athena_processing_input = create_athena_processing_input(processing_dir, name, data_def)\n",
    "                processing_inputs.append(athena_processing_input)\n",
    "            elif source_type == \"Redshift\":\n",
    "                redshift_processing_input = create_redshift_processing_input(processing_dir, name, data_def)\n",
    "                processing_inputs.append(redshift_processing_input)\n",
    "            else:\n",
    "                raise ValueError(f\"{source_type} is not supported for Data Wrangler Processing.\")\n",
    "    return processing_inputs\n",
    "\n",
    "\n",
    "def create_container_arguments(output_name, output_content_type):\n",
    "    output_config = {output_name: {\"content_type\": output_content_type}}\n",
    "    return [f\"--output-config '{json.dumps(output_config)}'\"]\n",
    "\n",
    "\n",
    "# Create Processing Job Arguments\n",
    "processing_job_arguments = {\n",
    "    \"AppSpecification\": {\n",
    "        \"ContainerArguments\": create_container_arguments(output_name, output_content_type),\n",
    "        \"ImageUri\": container_uri,\n",
    "    },\n",
    "    \"ProcessingInputs\": create_processing_inputs(processing_dir, flow, flow_uri),\n",
    "    \"ProcessingOutputConfig\": {\n",
    "        \"Outputs\": [\n",
    "            {\n",
    "                \"OutputName\": output_name,\n",
    "                \"S3Output\": {\n",
    "                    \"S3Uri\": output_path,\n",
    "                    \"LocalPath\": os.path.join(processing_dir, \"output\"),\n",
    "                    \"S3UploadMode\": \"EndOfJob\",\n",
    "                },\n",
    "            },\n",
    "        ],\n",
    "    },\n",
    "    \"ProcessingJobName\": processing_job_name,\n",
    "    \"ProcessingResources\": {\n",
    "        \"ClusterConfig\": {\n",
    "            \"InstanceCount\": instance_count,\n",
    "            \"InstanceType\": instance_type,\n",
    "            \"VolumeSizeInGB\": 30,\n",
    "        }\n",
    "    },\n",
    "    \"RoleArn\": iam_role,\n",
    "    \"StoppingCondition\": {\n",
    "        \"MaxRuntimeInSeconds\": 86400,\n",
    "    },\n",
    "}"
   ]
  },
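  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the parsing step concrete, the sketch below walks a tiny, hypothetical flow dictionary\n",
    "and reports which nodes would become Processing Job inputs. The flow content, the\n",
    "`operator_id` key, and the helper name `list_flow_sources` are assumptions made for\n",
    "illustration; a real .flow file carries many more fields. Only the keys read by the functions\n",
    "above (`nodes`, `parameters`, `dataset_definition`, `name`, `datasetSourceType`) come from this\n",
    "notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical, heavily trimmed flow used only for this example.\n",
    "example_flow = {\n",
    "    \"nodes\": [\n",
    "        {\"parameters\": {\"dataset_definition\": {\"name\": \"train.csv\", \"datasetSourceType\": \"S3\"}}},\n",
    "        {\"parameters\": {\"operator_id\": \"some-transform\"}},\n",
    "        {\"parameters\": {\"dataset_definition\": {\"name\": \"orders\", \"datasetSourceType\": \"Athena\"}}},\n",
    "    ]\n",
    "}\n",
    "\n",
    "\n",
    "def list_flow_sources(flow_dict):\n",
    "    \"\"\"Return (name, source_type) for every node that defines a dataset.\"\"\"\n",
    "    sources = []\n",
    "    for node in flow_dict[\"nodes\"]:\n",
    "        data_def = node.get(\"parameters\", {}).get(\"dataset_definition\")\n",
    "        if data_def is not None:\n",
    "            sources.append((data_def[\"name\"], data_def[\"datasetSourceType\"]))\n",
    "    return sources\n",
    "\n",
    "\n",
    "# Transform nodes carry no dataset definition, so only the two source nodes are listed.\n",
    "print(list_flow_sources(example_flow))"
   ]
  },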
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Start ProcessingJob\n",
    "The Processing Job is now submitted with the boto3 SageMaker client. The same client is polled\n",
    "for the job status, and this notebook waits until the job is no longer 'InProgress'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sagemaker_client = boto3.client(\"sagemaker\", endpoint_url=sagemaker_endpoint_url)\n",
    "create_response = sagemaker_client.create_processing_job(**processing_job_arguments)\n",
    "\n",
    "status = sagemaker_client.describe_processing_job(ProcessingJobName=processing_job_name)\n",
    "\n",
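    "# Poll once a minute. Sleeping before each check lets the loop exit as soon as the job\n",
    "# leaves 'InProgress', instead of sleeping once more after the final status.\n",
    "while status[\"ProcessingJobStatus\"] == \"InProgress\":\n",
    "    time.sleep(60)\n",
    "    status = sagemaker_client.describe_processing_job(ProcessingJobName=processing_job_name)\n",
    "    print(status[\"ProcessingJobStatus\"])\n",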
    "\n",
    "print(status)"
   ]
  },
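  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The polling loop above has no upper bound of its own. As a hedged sketch, the same idea could\n",
    "be wrapped in a helper that also gives up after a timeout. The name `wait_for_terminal_status`\n",
    "and its parameters are hypothetical, and a stand-in function replaces\n",
    "`describe_processing_job` so that the cell runs without AWS access. With the real client,\n",
    "`describe_fn` would be a function that calls\n",
    "`sagemaker_client.describe_processing_job(ProcessingJobName=processing_job_name)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "\n",
    "def wait_for_terminal_status(\n",
    "    describe_fn, poll_seconds=60, timeout_seconds=86400, sleep_fn=time.sleep\n",
    "):\n",
    "    \"\"\"Poll describe_fn until the status leaves 'InProgress' or the timeout expires.\"\"\"\n",
    "    waited = 0\n",
    "    response = describe_fn()\n",
    "    while response[\"ProcessingJobStatus\"] == \"InProgress\":\n",
    "        if waited >= timeout_seconds:\n",
    "            raise TimeoutError(f\"Job still in progress after {waited} seconds.\")\n",
    "        sleep_fn(poll_seconds)\n",
    "        waited += poll_seconds\n",
    "        response = describe_fn()\n",
    "    return response\n",
    "\n",
    "\n",
    "# Stand-in for describe_processing_job: reports 'InProgress' twice, then 'Completed'.\n",
    "example_statuses = iter([\"InProgress\", \"InProgress\", \"Completed\"])\n",
    "example_final = wait_for_terminal_status(\n",
    "    lambda: {\"ProcessingJobStatus\": next(example_statuses)},\n",
    "    poll_seconds=0,\n",
    ")\n",
    "print(example_final[\"ProcessingJobStatus\"])"
   ]
  },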
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Kick off SageMaker Training Job (Optional)\n",
    "Data Wrangler is a SageMaker tool for processing data to be used for machine learning. Now that\n",
    "the data has been processed, the next step is usually to train a model on it. The following shows\n",
    "an example of doing so using XGBoost, a popular algorithm.\n",
    "\n",
    "Note that the XGBoost objective (for 'binary', 'regression', or 'multiclass' problems), the\n",
    "hyperparameters, and the content_type used below may not suit the output data and may need to be\n",
    "changed to train a proper model. Furthermore, for CSV training, the algorithm assumes that\n",
    "the target variable is in the first column. For more information on SageMaker XGBoost, see\n",
    "https://docs.aws.amazon.com/sagemaker/latest/dg/xgboost.html.\n",
    "\n",
    "### Find Training Data path\n",
    "The following cell lists the objects under the output prefix to find the data location."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_client = boto3.client(\"s3\")\n",
    "list_response = s3_client.list_objects_v2(Bucket=bucket, Prefix=output_prefix)\n",
    "\n",
    "training_path = None\n",
    "\n",
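    "# \"Contents\" is absent from the response when no objects match the prefix.\n",
    "for content in list_response.get(\"Contents\", []):\n",
    "    # Skip the _SUCCESS marker; the last remaining key is used as the training data path.\n",
    "    if \"_SUCCESS\" not in content[\"Key\"]:\n",
    "        training_path = content[\"Key\"]\n",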
    "\n",
    "print(training_path)"
   ]
  },
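  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The key filtering can be sketched on its own, without calling S3. The helper name\n",
    "`select_data_keys` and the example keys are hypothetical; the only rule taken from the cell\n",
    "above is that keys containing `_SUCCESS` are skipped. Keys ending in `/` are also dropped here,\n",
    "on the assumption that they are empty folder placeholders. Note that `list_objects_v2` returns\n",
    "at most 1,000 keys per call, so larger outputs would need the boto3 paginator\n",
    "(`s3_client.get_paginator(\"list_objects_v2\")`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_data_keys(keys):\n",
    "    \"\"\"Drop '_SUCCESS' marker objects and keys that end in '/'.\"\"\"\n",
    "    return [key for key in keys if \"_SUCCESS\" not in key and not key.endswith(\"/\")]\n",
    "\n",
    "\n",
    "# Hypothetical keys, shaped like the contents of a job output prefix.\n",
    "example_keys = [\n",
    "    \"export-flow-example/output/run-1/\",\n",
    "    \"export-flow-example/output/run-1/_SUCCESS\",\n",
    "    \"export-flow-example/output/run-1/part-00000.csv\",\n",
    "    \"export-flow-example/output/run-1/part-00001.csv\",\n",
    "]\n",
    "example_data_keys = select_data_keys(example_keys)\n",
    "print(example_data_keys)"
   ]
  },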
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, the Training Job hyperparameters are set. For more information on XGBoost Hyperparameters,\n",
    "see https://xgboost.readthedocs.io/en/latest/parameter.html."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "region = boto3.Session().region_name\n",
    "container = sagemaker.image_uris.retrieve(\"xgboost\", region, \"1.2-1\")\n",
    "hyperparameters = {\n",
    "    \"max_depth\": \"5\",\n",
    "    \"objective\": \"reg:squarederror\",\n",
    "    \"num_round\": \"10\",\n",
    "}\n",
    "train_content_type = \"application/x-parquet\" if output_content_type.upper() == \"PARQUET\" else \"text/csv\"\n",
    "train_input = sagemaker.inputs.TrainingInput(\n",
    "    s3_data=f\"s3://{bucket}/{training_path}\",\n",
    "    content_type=train_content_type,\n",
    ")"
   ]
  },
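  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the objective has to match the problem type, the sketch below shows one possible way\n",
    "to pick it. The mapping and the helper name `build_example_hyperparameters` are assumptions\n",
    "made for illustration, not settings recommended by Data Wrangler. The objective strings are\n",
    "standard XGBoost objectives, and `multi:softmax` additionally requires `num_class`. The result\n",
    "is stored under a separate name so that the `hyperparameters` variable above is left unchanged."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical mapping from problem type to an XGBoost objective.\n",
    "EXAMPLE_OBJECTIVES = {\n",
    "    \"binary\": \"binary:logistic\",\n",
    "    \"regression\": \"reg:squarederror\",\n",
    "    \"multiclass\": \"multi:softmax\",\n",
    "}\n",
    "\n",
    "\n",
    "def build_example_hyperparameters(problem_type, num_class=None):\n",
    "    \"\"\"Return a hyperparameter dictionary whose objective matches the problem type.\"\"\"\n",
    "    if problem_type not in EXAMPLE_OBJECTIVES:\n",
    "        raise ValueError(f\"Unknown problem type: {problem_type!r}\")\n",
    "    params = {\n",
    "        \"max_depth\": \"5\",\n",
    "        \"objective\": EXAMPLE_OBJECTIVES[problem_type],\n",
    "        \"num_round\": \"10\",\n",
    "    }\n",
    "    if problem_type == \"multiclass\":\n",
    "        # multi:softmax needs the number of classes.\n",
    "        if num_class is None or num_class < 2:\n",
    "            raise ValueError(\"num_class must be at least 2 for multiclass problems.\")\n",
    "        params[\"num_class\"] = str(num_class)\n",
    "    return params\n",
    "\n",
    "\n",
    "example_hyperparameters = build_example_hyperparameters(\"multiclass\", num_class=3)\n",
    "print(example_hyperparameters)"
   ]
  },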
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The TrainingJob configuration is set using the SageMaker Python SDK Estimator, which is then fit\n",
    "using the training data from the ProcessingJob that was run earlier."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "estimator = sagemaker.estimator.Estimator(\n",
    "    container,\n",
    "    iam_role,\n",
    "    hyperparameters=hyperparameters,\n",
    "    instance_count=1,\n",
    "    instance_type=\"ml.m5.2xlarge\",\n",
    ")\n",
    "estimator.fit({\"train\": train_input})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Cleanup\n",
    "Uncomment the following code cell to revert the SageMaker Python SDK to the original version used\n",
    "before running this notebook. This notebook installs SageMaker Python SDK 2.20.0, which may\n",
    "cause other example notebooks to break. To learn more about the changes introduced in the\n",
    "SageMaker Python SDK 2.x update, see\n",
    "[Use Version 2.x of the SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/v2.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# _ = subprocess.check_call(\n",
    "#     [sys.executable, \"-m\", \"pip\", \"install\", f\"sagemaker=={original_version}\"]\n",
    "# )"
   ]
  }
 ],
 "metadata": {
  "instance_type": "ml.t3.medium",
  "kernelspec": {
   "display_name": "Python 3 (Data Science)",
   "language": "python",
   "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/datascience-1.0"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
